# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import warnings

import jax.numpy as jnp
import numpy as np
from jax.lax import while_loop

import brainpy.math as bm
from brainpy.math.object_transform.base import BrainPyObject
from brainpy.types import ArrayType
from .utils import (Sigmoid,
                    Regularization,
                    L1Regularization,
                    L1L2Regularization,
                    L2Regularization,
                    polynomial_features,
                    normalize)

__all__ = [
    # brainpy_object class for offline training algorithm
    'OfflineAlgorithm',

    # training methods
    'LinearRegression', 'linear_regression',
    'RidgeRegression', 'ridge_regression',
    'LassoRegression',
    'LogisticRegression',
    'PolynomialRegression',
    'PolynomialRidgeRegression',
    'ElasticNetRegression',

    # general supports
    'get_supported_offline_methods',
    'register_offline_method',
]

name2func = dict()


class OfflineAlgorithm(BrainPyObject):
    """Base class for offline training algorithm."""

    def __init__(self, name=None):
        super(OfflineAlgorithm, self).__init__(name=name)

    def __call__(self, targets, inputs, outputs=None):
        """The training procedure.

        Parameters::

        targets: ArrayType
          The 2d target data with the shape of `(num_batch, num_output)`.
        inputs: ArrayType
          The 2d input data with the shape of `(num_batch, num_input)`.
        outputs: ArrayType
          The 2d output data with the shape of `(num_batch, num_output)`.

        Returns::

        weight: ArrayType
          The weights after fit.
        """
        return self.call(targets, inputs, outputs)

    def call(self, targets, inputs, outputs=None) -> ArrayType:
        """The training procedure.

        Parameters::

        targets: ArrayType
          The 3d target data with the shape of `(num_batch, num_time, num_output)`,
          or the 2d target data with the shape of `(num_time, num_output)`.

        inputs: ArrayType
          The 3d input data with the shape of `(num_batch, num_time, num_input)`,
          or the 2d input data with the shape of `(num_time, num_input)`.

        outputs: ArrayType
          The 3d output data with the shape of `(num_batch, num_time, num_output)`,
          or the 2d output data with the shape of `(num_time, num_output)`.

        Returns::

        weight: ArrayType
          The weights after fit.
        """
        raise NotImplementedError('The "call" function must be implemented by the subclass itself.')

    def __repr__(self):
        return self.__class__.__name__


def _check_data_2d_atls(x):
    if x.ndim < 2:
        raise ValueError(f'Data must be at least a 2d tensor. But we got {x.ndim}d: {x.shape}.')
    if x.ndim != 2:
        return bm.flatten(x, end_dim=-2)
    else:
        return x
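
# A minimal sketch of what the helper above does, assuming a batched 3d input
# (the shapes are made up for illustration):
#
#   x = bm.random.rand(4, 100, 8)     # (num_batch, num_time, num_input)
#   x2d = _check_data_2d_atls(x)      # leading axes flattened -> (400, 8)
#
# A 1d input raises a ValueError; a 2d input is returned unchanged.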


class RegressionAlgorithm(OfflineAlgorithm):
    """ Base regression model. Models the relationship between a scalar dependent variable y and the independent
    variables X.

    Parameters::

    max_iter: int
      The number of training iterations the algorithm will tune the weights for.
    learning_rate: float
      The step length that will be used when updating the weights.
    """

    def __init__(
        self,
        max_iter: int = None,
        learning_rate: float = None,
        regularizer: Regularization = None,
        name: str = None
    ):
        super(RegressionAlgorithm, self).__init__(name=name)
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.regularizer = regularizer

    def initialize(self, *args, **kwargs):
        pass

    def init_weights(self, n_features, n_out):
        """ Initialize weights randomly [-1/N, 1/N] """
        limit = 1 / np.sqrt(n_features)
        return bm.random.uniform(-limit, limit, (n_features, n_out))

    def gradient_descent_solve(self, targets, inputs, outputs=None):
        # checking
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))

        # initialize weights
        w = self.init_weights(inputs.shape[1], targets.shape[1])

        def cond_fun(a):
            i, par_old, par_new = a
            # Stop once the weights have converged or `max_iter` is reached
            return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)),
                                   i < self.max_iter)

        def body_fun(a):
            i, _, par_new = a
            # Gradient of regularization loss w.r.t w
            y_pred = inputs.dot(par_new)
            grad_w = jnp.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new)
            # Update the weights
            par_new2 = par_new - self.learning_rate * grad_w
            return i + 1, par_new, par_new2

        # Tune the weights until convergence or `max_iter` is reached
        r = while_loop(cond_fun, body_fun, (0, w - 1e-8, w))
        return r[-1]

    def predict(self, W, X):
        return jnp.dot(X, W)
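
# The gradient-descent solver above iterates the standard regularized
# least-squares update; as a sketch of the math it implements:
#
#   y_pred = X @ W
#   grad_W = X.T @ (y_pred - Y) + d(regularizer)/dW
#   W_new  = W - learning_rate * grad_W
#
# and it stops once the weights no longer change (per jnp.allclose) or
# `max_iter` iterations have elapsed.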


class LinearRegression(RegressionAlgorithm):
    """Training algorithm of least-square regression.

    Parameters::

    name: str
      The name of the algorithm.
    """

    def __init__(
        self,
        name: str = None,

        # parameters for using gradient descent
        max_iter: int = 1000,
        learning_rate: float = 0.001,
        gradient_descent: bool = False,
    ):
        super(LinearRegression, self).__init__(name=name,
                                               max_iter=max_iter,
                                               learning_rate=learning_rate,
                                               regularizer=Regularization(0.))
        self.gradient_descent = gradient_descent

    def call(self, targets, inputs, outputs=None):
        # checking
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))

        # solving
        if self.gradient_descent:
            return self.gradient_descent_solve(targets, inputs)
        else:
            # jnp.linalg.lstsq returns (solution, residuals, rank, singular_values);
            # the fitted weights are the first element
            weights = jnp.linalg.lstsq(inputs, targets)
            return weights[0]


linear_regression = LinearRegression()

name2func['linear'] = LinearRegression
name2func['lstsq'] = LinearRegression
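
# A minimal usage sketch of the least-squares solver through the shared
# instance above (the data is made up for illustration):
#
#   X = bm.random.rand(200, 10)          # (num_batch, num_input)
#   Y = X @ bm.random.rand(10, 3)        # (num_batch, num_output)
#   W = linear_regression(Y, X)          # targets first, then inputs
#   Y_hat = linear_regression.predict(W, X)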


class RidgeRegression(RegressionAlgorithm):
    """Training algorithm of ridge regression.

    Parameters::

    alpha: float
      The regularization coefficient.

      .. versionadded:: 2.2.0

    beta: float
      The regularization coefficient.

      .. deprecated:: 2.2.0
         Please use `alpha` to set regularization factor.

    name: str
      The name of the algorithm.
    """

    def __init__(
        self,
        alpha: float = 1e-7,
        beta: float = None,
        name: str = None,

        # parameters for using gradient descent
        max_iter: int = 1000,
        learning_rate: float = 0.001,
        gradient_descent: bool = False,
    ):
        if beta is not None:
            warnings.warn(f"Please use 'alpha' to set regularization factor. "
                          f"'beta' has been deprecated since version 2.2.0.",
                          UserWarning)
            alpha = beta
        super(RidgeRegression, self).__init__(name=name,
                                              max_iter=max_iter,
                                              learning_rate=learning_rate,
                                              regularizer=L2Regularization(alpha=alpha))
        self.gradient_descent = gradient_descent

    def call(self, targets, inputs, outputs=None):
        # checking
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))

        # solving
        if self.gradient_descent:
            return self.gradient_descent_solve(targets, inputs)
        else:
            temp = inputs.T @ inputs
            if self.regularizer.alpha > 0.:
                temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1])
            weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets)
            return weights

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.regularizer.alpha})'


ridge_regression = RidgeRegression()

name2func['ridge'] = RidgeRegression
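
# A minimal usage sketch of the ridge solver; with gradient_descent=False
# (the default) it computes the closed-form solution
# W = pinv(X.T @ X + alpha * I) @ (X.T @ Y):
#
#   algo = RidgeRegression(alpha=1e-4)
#   X = bm.random.rand(200, 10)
#   Y = X @ bm.random.rand(10, 2)
#   W = algo(Y, X)
#   Y_hat = algo.predict(W, X)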


class LassoRegression(RegressionAlgorithm):
    """Lasso regression method for offline training.

    Parameters::

    alpha: float
      Constant that multiplies the L1 term. Defaults to 1.0.
      `alpha = 0` is equivalent to ordinary least squares.
    max_iter: int
      The maximum number of iterations.
    degree: int
      The degree of the polynomial that the independent variable X will be transformed to.
    name: str
      The name of the algorithm.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        degree: int = 2,
        add_bias: bool = False,
        name: str = None,

        # parameters for using gradient descent
        max_iter: int = 1000,
        learning_rate: float = 0.001,
        gradient_descent: bool = True,
    ):
        super(LassoRegression, self).__init__(name=name,
                                              max_iter=max_iter,
                                              learning_rate=learning_rate,
                                              regularizer=L1Regularization(alpha=alpha))
        self.gradient_descent = gradient_descent
        assert gradient_descent, 'LassoRegression only supports gradient descent training.'
        self.add_bias = add_bias
        self.degree = degree

    def call(self, targets, inputs, outputs=None):
        # checking
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))

        # solving
        inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias))
        return super(LassoRegression, self).gradient_descent_solve(targets, inputs)

    def predict(self, W, X):
        X = _check_data_2d_atls(bm.as_jax(X))
        X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias))
        return super(LassoRegression, self).predict(W, X)


name2func['lasso'] = LassoRegression
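
# A minimal usage sketch of the lasso solver; the inputs are expanded to
# polynomial features and normalized internally, so predictions must go
# through the same instance:
#
#   algo = LassoRegression(alpha=0.1, degree=2)
#   X = bm.random.rand(100, 4)
#   Y = bm.random.rand(100, 1)
#   W = algo(Y, X)                       # gradient descent is required here
#   Y_hat = algo.predict(W, X)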


class LogisticRegression(RegressionAlgorithm):
    """Logistic regression method for offline training.

    Parameters::

    learning_rate: float
        The step length that will be taken when following the negative gradient during
        training.
    gradient_descent: boolean
      Whether to use gradient descent when training. If False, the weights are
      updated with Newton's method instead.
    max_iter: int
      The number of iteration to optimize the parameters.
    name: str
      The name of the algorithm.
    """

    def __init__(
        self,
        learning_rate: float = .1,
        gradient_descent: bool = True,
        max_iter: int = 4000,
        name: str = None,
    ):
        super(LogisticRegression, self).__init__(name=name,
                                                 max_iter=max_iter,
                                                 learning_rate=learning_rate)
        self.gradient_descent = gradient_descent
        self.sigmoid = Sigmoid()

    def call(self, targets, inputs, outputs=None) -> ArrayType:
        # prepare data
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))
        if targets.shape[-1] != 1:
            raise ValueError(f'Targets must have a single output dimension, '
                             f'but got multiple variables: {targets.shape}.')
        targets = targets.flatten()

        # initialize parameters as a 1d vector to match the flattened targets
        param = self.init_weights(inputs.shape[1], 1).flatten()

        def cond_fun(a):
            i, par_old, par_new = a
            return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)),
                                   i < self.max_iter)

        def body_fun(a):
            i, par_old, par_new = a
            # Make a new prediction
            y_pred = self.sigmoid(inputs.dot(par_new))
            if self.gradient_descent:
                # Move against the gradient of the loss function with
                # respect to the parameters to minimize the loss
                par_new2 = par_new - self.learning_rate * (y_pred - targets).dot(inputs)
            else:
                # Newton's method: build a diagonal matrix from the sigmoid gradient
                gradient = self.sigmoid.grad(inputs.dot(par_new))
                diag_grad = jnp.diag(gradient)
                par_new2 = jnp.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot(
                    diag_grad.dot(inputs).dot(par_new) + targets - y_pred)
            return i + 1, par_new, par_new2

        # Tune the parameters until convergence or `max_iter` is reached
        r = while_loop(cond_fun, body_fun, (0, param + 1., param))
        return r[-1]

    def predict(self, W, X):
        return self.sigmoid(X @ W)


name2func['logistic'] = LogisticRegression
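
# A minimal usage sketch of logistic regression; the targets must be a single
# binary column, and `predict` returns sigmoid probabilities (the synthetic
# labels below are made up for illustration):
#
#   algo = LogisticRegression(learning_rate=0.1, max_iter=1000)
#   X = bm.random.rand(100, 5)
#   Y = (X.sum(axis=1, keepdims=True) > 2.5).astype(float)   # (100, 1)
#   W = algo(Y, X)
#   prob = algo.predict(W, X)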


class PolynomialRegression(LinearRegression):
    def __init__(
        self,
        degree: int = 2,
        name: str = None,
        add_bias: bool = False,

        # parameters for using gradient descent
        max_iter: int = 1000,
        learning_rate: float = 0.001,
        gradient_descent: bool = True,
    ):
        super(PolynomialRegression, self).__init__(name=name,
                                                   max_iter=max_iter,
                                                   learning_rate=learning_rate,
                                                   gradient_descent=gradient_descent)
        self.degree = degree
        self.add_bias = add_bias

    def call(self, targets, inputs, outputs=None):
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))
        inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)
        return super(PolynomialRegression, self).call(targets, inputs)

    def predict(self, W, X):
        X = _check_data_2d_atls(bm.as_jax(X))
        X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias)
        return super(PolynomialRegression, self).predict(W, X)


name2func['polynomial'] = PolynomialRegression


class PolynomialRidgeRegression(RidgeRegression):
    def __init__(
        self,
        alpha: float = 1.0,
        degree: int = 2,
        name: str = None,
        add_bias: bool = False,

        # parameters for using gradient descent
        max_iter: int = 1000,
        learning_rate: float = 0.001,
        gradient_descent: bool = True,
    ):
        super(PolynomialRidgeRegression, self).__init__(alpha=alpha,
                                                        name=name,
                                                        max_iter=max_iter,
                                                        learning_rate=learning_rate,
                                                        gradient_descent=gradient_descent)
        self.degree = degree
        self.add_bias = add_bias

    def call(self, targets, inputs, outputs=None):
        # checking
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))
        inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)
        return super(PolynomialRidgeRegression, self).call(targets, inputs)

    def predict(self, W, X):
        X = _check_data_2d_atls(bm.as_jax(X))
        X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias)
        return super(PolynomialRidgeRegression, self).predict(W, X)


name2func['polynomial_ridge'] = PolynomialRidgeRegression
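
# A minimal usage sketch of the polynomial solvers; both expand the inputs
# with `polynomial_features` before fitting, and the ridge variant only adds
# an L2 penalty on top:
#
#   algo = PolynomialRegression(degree=3)
#   # algo = PolynomialRidgeRegression(alpha=0.5, degree=3)
#   X = bm.random.rand(100, 2)
#   Y = bm.random.rand(100, 1)
#   W = algo(Y, X)
#   Y_hat = algo.predict(W, X)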


class ElasticNetRegression(RegressionAlgorithm):
    """

    Parameters:
    -----------
    degree: int
        The degree of the polynomial that the independent variable X will be transformed to.
    reg_factor: float
        The factor that will determine the amount of regularization and feature
        shrinkage.
    l1_ration: float
        Weighs the contribution of l1 and l2 regularization.
    n_iterations: float
        The number of training iterations the algorithm will tune the weights for.
    learning_rate: float
        The step length that will be used when updating the weights.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        degree: int = 2,
        l1_ratio: float = 0.5,
        name: str = None,
        add_bias: bool = False,

        # parameters for using gradient descent
        max_iter: int = 1000,
        learning_rate: float = 0.001,
        gradient_descent: bool = True,
    ):
        super(ElasticNetRegression, self).__init__(
            name=name,
            max_iter=max_iter,
            learning_rate=learning_rate,
            regularizer=L1L2Regularization(alpha=alpha, l1_ratio=l1_ratio)
        )
        self.degree = degree
        self.add_bias = add_bias
        self.gradient_descent = gradient_descent
        assert gradient_descent, 'ElasticNetRegression only supports gradient descent training.'

    def call(self, targets, inputs, outputs=None):
        # checking
        inputs = _check_data_2d_atls(bm.as_jax(inputs))
        targets = _check_data_2d_atls(bm.as_jax(targets))
        # solving
        inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias))
        return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs)

    def predict(self, W, X):
        X = _check_data_2d_atls(bm.as_jax(X))
        X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias))
        return super(ElasticNetRegression, self).predict(W, X)


name2func['elastic_net'] = ElasticNetRegression
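
# A minimal usage sketch of the elastic net solver; `l1_ratio` interpolates
# between pure L1 (1.0) and pure L2 (0.0) regularization:
#
#   algo = ElasticNetRegression(alpha=0.5, l1_ratio=0.5, degree=2)
#   X = bm.random.rand(100, 3)
#   Y = bm.random.rand(100, 1)
#   W = algo(Y, X)                       # gradient descent is required here
#   Y_hat = algo.predict(W, X)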


def get_supported_offline_methods():
    """Get all supported offline training methods."""
    return tuple(name2func.keys())


def register_offline_method(name: str, method: OfflineAlgorithm):
    """Register a new offline learning method.

    Parameters::

    name: str
      The method name.
    method: OfflineAlgorithm
      The offline training method instance.
    """
    if name in name2func:
        raise ValueError(f'"{name}" has been registered in offline training methods.')
    if not isinstance(method, OfflineAlgorithm):
        raise ValueError(f'"method" must be an instance {OfflineAlgorithm.__name__}, but we got {type(method)}')
    name2func[name] = method


def get(name: str) -> OfflineAlgorithm:
    """Get the training function according to the training method name."""
    if name not in name2func:
        raise ValueError(f'Unknown offline method "{name}". '
                         f'All supported methods are: {get_supported_offline_methods()}.')
    return name2func[name]
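
# A minimal sketch of registering and retrieving a custom method (the name
# 'my_ridge' is made up for illustration):
#
#   register_offline_method('my_ridge', RidgeRegression(alpha=1e-3))
#   algo = get('my_ridge')
#   W = algo(Y, X)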
